{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Video Model Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTES:  \n",
    "* It's assumed that there's a pretrained generator from the ColorizeTrainingStable notebook available at the specified path.\n",
    "* This is \"NoGAN\" based training, described in the DeOldify readme."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai import *\n",
    "from fastai.vision import *\n",
    "from fastai.callbacks.tensorboard import *\n",
    "from fastai.vision.gan import *\n",
    "from fasterai.generators import *\n",
    "from fasterai.critics import *\n",
    "from fasterai.dataset import *\n",
    "from fasterai.loss import *\n",
    "from fasterai.save import *\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from PIL import ImageFile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
    "path_hr = path\n",
    "path_lr = path/'bandw'\n",
    "\n",
    "proj_id = 'VideoModel'\n",
    "gen_name = proj_id + '_gen'\n",
    "pre_gen_name = gen_name + '_0'\n",
    "crit_name = proj_id + '_crit'\n",
    "\n",
    "name_gen = proj_id + '_image_gen'\n",
    "path_gen = path/name_gen\n",
    "\n",
    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
    "\n",
    "nf_factor = 2\n",
    "xtra_tfms=[noisify(p=0.8)]\n",
    "pct_start = 1e-8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(bs:int, sz:int, keep_pct:float):\n",
    "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
    "                             random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)\n",
    "\n",
    "def get_crit_data(classes, bs, sz):\n",
    "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
    "    ll = src.label_from_folder(classes=classes)\n",
    "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
    "           .databunch(bs=bs).normalize(imagenet_stats))\n",
    "    return data\n",
    "    \n",
    "def save_preds(dl):\n",
    "    i=0\n",
    "    names = dl.dataset.items\n",
    "    \n",
    "    for b in dl:\n",
    "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
    "        for o in preds:\n",
    "            o.save(path_gen/names[i].name)\n",
    "            i += 1\n",
    "            \n",
    "def save_gen_images():\n",
    "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
    "    path_gen.mkdir(exist_ok=True)\n",
    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
    "    save_preds(data_gen.fix_dl)\n",
    "    PIL.Image.open(path_gen.ls()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune Generator With Noise Augmented Images."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### This helps the generator better deal with noisy/grainy video (which is pretty normal)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192\n",
    "keep_pct=0.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = learn_gen.load(pre_gen_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen.save(pre_gen_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repeatable GAN Cycle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE\n",
    "Best results so far have been based only doing a single run of the cells below (otherwise glitches are introduced that are visible in video).  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_checkpoint_num = 0\n",
    "checkpoint_num = old_checkpoint_num + 1\n",
    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save Generated Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_gen_images()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pretrain Critic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=16\n",
    "sz=192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.fit_one_cycle(4, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_critic.save(crit_new_checkpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit=None\n",
    "learn_gen=None\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=5e-6\n",
    "sz=192\n",
    "bs=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
    "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
    "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
    "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
    "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))\n",
    "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Instructions:  \n",
    "Find the checkpoint just before where glitches start to be introduced.  So far this has been found at the point of iterating through 1.4% of the data when using learning rate of 1e-5, and at 2.2% of the data for 5e-6."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
    "learn_gen.freeze_to(-1)\n",
    "learn.fit(1,lr)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
